{
 "cells": [
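  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6.10 The Adam Algorithm\n",
    "As the saying goes, “the empire, long united, must divide; long divided, must unite,” and the contest among deep learning optimization algorithms plays out the same story. In the previous sections we started from gradient descent and introduced two of its variants, stochastic gradient descent and mini-batch stochastic gradient descent. We then covered the momentum method, which improves how the gradient is used, and the Adagrad, RMSProp, and Adadelta algorithms, which improve the learning rate. Has anyone thought of combining their strengths into a single method? Naturally: that is the Adam algorithm, proposed by Diederik Kingma and Jimmy Ba in 2014. This section covers it in detail."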
   ]
  },
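  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.10.1 Basic Principle\n",
    "\n",
    "Adam builds on RMSProp: it keeps RMSProp's exponentially weighted moving average of squared gradients to scale the learning rate, and adds a momentum-style moving average of the gradient itself. Adam is widely used to train neural networks because it adapts the learning rate for each parameter, which tends to make training smoother. So how exactly does Adam improve on plain gradient descent? Let's look at its update rules:\n",
    "$$m_t = \\beta_1 m_{t-1} + (1 - \\beta_1)g_t$$\n",
    "$$v_t = \\beta_2 v_{t-1} + (1 - \\beta_2)g_t^2$$\n",
    "$$\\hat{m}_t = \\frac{m_t}{1 - \\beta_1^t}$$\n",
    "$$\\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t}$$\n",
    "$$\\Delta w_t = -\\frac{\\alpha \\hat{m}_t}{\\sqrt{\\hat{v}_t} + \\epsilon}$$\n",
    "\n",
    "Here $g_t$ is the gradient at step $t$ (and $g_t^2$ its element-wise square); $m_t$ and $v_t$ are exponentially weighted estimates of the first moment (mean) and second moment (uncentered variance) of the gradient; $\\hat{m}_t$ and $\\hat{v}_t$ are their bias-corrected versions, which compensate for $m_t$ and $v_t$ starting from zero; $\\alpha$ is the learning rate; $\\beta_1$ and $\\beta_2$ are the decay rates of the two moving averages; and $\\epsilon$ is a small positive constant that keeps the denominator from being $0$.\n"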
   ]
  },
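  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the bias correction does, it can help to work through the very first update by hand. As a worked example (taking the standard initialization $m_0 = v_0 = 0$ and a scalar gradient $g_1 \\neq 0$), substituting $t = 1$ into the formulas above gives:\n",
    "$$m_1 = (1 - \\beta_1)g_1, \\qquad \\hat{m}_1 = \\frac{m_1}{1 - \\beta_1} = g_1$$\n",
    "$$v_1 = (1 - \\beta_2)g_1^2, \\qquad \\hat{v}_1 = \\frac{v_1}{1 - \\beta_2} = g_1^2$$\n",
    "$$\\Delta w_1 = -\\frac{\\alpha g_1}{\\lvert g_1 \\rvert + \\epsilon} \\approx -\\alpha \\cdot \\mathrm{sign}(g_1)$$\n",
    "\n",
    "So the first step has size roughly $\\alpha$ regardless of the gradient's scale. Without the correction, and with the commonly used values $\\beta_1 = 0.9$ and $\\beta_2 = 0.999$ (illustrative values chosen for this example), the ratio $m_1 / \\sqrt{v_1}$ would instead be $0.1 / \\sqrt{0.001} \\approx 3.16$ in magnitude, so the first step would be about three times larger than $\\alpha$."
   ]
  },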
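  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.10.2 Algorithm Workflow\n",
    "\n",
    "When training a neural network with Adam, we need to set a few hyperparameters: the learning rate $\\alpha$, the decay rates $\\beta_1$ and $\\beta_2$ of the two exponentially weighted averages, and a small positive constant $\\epsilon$. In PyTorch's `torch.optim.Adam` the defaults are `lr=1e-3`, `betas=(0.9, 0.999)`, and `eps=1e-8`. Training then proceeds as follows:\n",
    "\n",
    "1. Initialize the network's weights and biases, and define the loss function and the Adam optimizer.\n",
    "2. Run a forward pass on the training data.\n",
    "3. Compute the loss.\n",
    "4. Run backpropagation to compute the gradients.\n",
    "5. Use the Adam optimizer to update the weights and biases.\n",
    "6. Repeat steps 2-5 until the specified number of training steps is reached.\n",
    "\n",
    "### 6.10.3 Code Example\n",
    "\n",
    "The following example trains a linear model with PyTorch's built-in `torch.optim.Adam` and plots the training loss:"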
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8+yak3AAAACXBIWXMAAAsTAAALEwEAmpwYAAAcD0lEQVR4nO3deZCcd53f8fe3jzl67lPnSKPDli2MQWZ8YBMX2CZlwAFSQIIDrClcpWRDsrBFFQVha6lUqlJL7RYsW2zYcsAcCWXAXoclZDmMTQI2XtvjA1mWbEuydYzO0Uijkebs45s/+mlpPJas0XTPPP08/XlVTXU/Tz/dz/fxI3+eX/+e3/O0uTsiIhI9ibALEBGRhVGAi4hElAJcRCSiFOAiIhGlABcRiajUUq6su7vb+/v7K/Z5uYKz8/AYK9sa6Wquq9jniohUk6effvq4u/fMnb+kAd7f38/g4GDFPs/dedOXf8m/vraPL/+LN1Xsc0VEqomZ7Tvf/Eh3oZgZa7ua2DcyEXYpIiJLLtIBDtDflWHvyHjYZYiILLnIB/jariYOnJggX9AVpSJSWyIf4P1dGbJ559DoZNiliIgsqegHeHcTgLpRRKTmRD/Au0oBrhOZIlJbIh/gvS31NKQT7DuuFriI1JbIB3giYaztbFILXERqTuQDHGCthhKKSA2KRYCv625i/4iGEopIbblogJvZvWZ2zMy2n+e1z5mZm1n34pQ3P2u7mpjJFzgyNhVmGSIiS2o+LfDvArfPnWlmfcA/B/ZXuKZL1t+VAdCJTBGpKRcNcHf/LXDiPC99Dfg8EHq/xdpuDSUUkdqzoD5wM/sAcNDd/zCPZbea2aCZDQ4PDy9kdRe1orWBulSCfTqRKSI15JID3MwywH8C/nw+y7v7Pe4+4O4DPT2vu51tRRSHEmZ4VV0oIlJDFtIC3wCsA/5gZnuB1cAzZra8koVdKt1WVkRqzSX/oIO7Pw/0lqaDEB9w9+MVrOuS9XdleHT3MIWCk0hYmKWIiCyJ+QwjvA94HNhkZkNmdvfil3Xp1nY3MZUtcPS0hhKKSG24aAvc3e+8yOv9FaumDOuDkSivHh9nRVtjyNWIiCy+WFyJCcWrMQFeGdaJTBGpDbEJ8OWtDTSmkxqJIiI1IzYBnkgY/d1NvDJ8JuxSRESWRGwCHIr94GqBi0itiFeA9zRx4OQkM7lC2KWIiCy6WAX4uu4m8gVn/wld0CMi8RerAF/f0wygfnARqQmxCvB1s8aCi4jEXawCvK0xTXdzncaCi0hNiFWAQ7EVrha4iNSCWAb4KwpwEakBsQvw9T3NHD8zzdhUNuxSREQWVewC/OyJTPWDi0jMxS7AN/QEN7U6rqGEIhJvsQvwvs4MCVMLXETiL3YBXp9K0teZYY9OZIpIzMUuwCEYSqgWuIjEXHwD/Pg4hYKHXYqIyKKJZYCv72lmMpvX72OKSKzFMsA3BEMJ9xxTN4qIxNd8fpX+XjM7ZmbbZ837SzN70cy2mdn/MrP2Ra3yEm3oLd6VcPex0yFXIiKyeObTAv8ucPuceQ8BV7n71cDLwBcrXFdZelvqaalPsVu3lRWRGLtogLv7b4ETc+b9yt1zweQ/AasXobYFMzM29Daz+5gCXETiqxJ94J8Cfn6hF81sq5kNmtng8PBwBVY3Pxt7m9mtPnARibGyAtzMvgTkgB9caBl3v8fdB9x9oKenp5zVXZKNvcWbWp2a0E2tRCSeFhzgZvZJ4A7gY+5edQOuLyudyBzWiUwRiacFBbiZ3Q58Hni/u1flLwhvPDsSRf3gIhJP8xlGeB/wOLDJzIbM7G7gG0AL8JCZPWdmf7fIdV6y1R0Z6lIJBbiIxFbqYgu4+53nmf3tRailopIJY313kwJcRGIrlldilmzsbdZYcBGJrdgH+NDJSaay+bBLERGpuNgHuDvsUStcRGIo9gEOGokiIvEU6wBf191EwmCPAlxEYijWAV6fSrKm
M6MTmSISS7EOcCjdE0UBLiLxE/sA39DbzKvHx8nlC2GXIiJSUbEP8I09zWTzzv4TVXnFv4jIgsU/wIORKLvUjSIiMRP7AL9sWQsAu47qroQiEi+xD/Dm+hSrOxp56aha4CISL7EPcIBNy1p4+Yha4CISLzUR4Jcvb2HP8BlmchqJIiLxURMBvmlZC7mCs3dEv5EpIvFREwF+eXAi8yV1o4hIjNREgK/vaSKZMI1EEZFYqYkAb0gn6e/K8JICXERipCYCHGDT8hZe1lBCEYmRmgnwy3pb2Dsyrl/nEZHYmM+v0t9rZsfMbPuseZ1m9pCZ7QoeOxa3zPJtWt6Cu37cQUTiYz4t8O8Ct8+Z9wXgYXe/DHg4mK5qGokiInFz0QB3998CJ+bM/gDwveD594APVrasyuvvylCXTPCyTmSKSEwstA98mbsfDp4fAZZdaEEz22pmg2Y2ODw8vMDVlS+VTLCht1kjUUQkNso+ienuDvgbvH6Puw+4+0BPT0+5qyvLpmXNuieKiMTGQgP8qJmtAAgej1WupMVz+fIWDp2aYmwqG3YpIiJlW2iA/xS4K3h+F/APlSlncW06e29wjUQRkeibzzDC+4DHgU1mNmRmdwN/AbzbzHYBtwXTVa80EuXFI2MhVyIiUr7UxRZw9zsv8NKtFa5l0a3uaKSlIcXOwwpwEYm+mrkSE8DMuHJ5KzsP60SmiERfTQU4wOaVrew8PEahcMGBMyIikVBzAX7lihYmZvLsPzERdikiImWpwQBvBVA/uIhEXs0F+OXLWkgmjB0KcBGJuJoL8IZ0kvXdTWqBi0jk1VyAQ7EbRSNRRCTqajLAN69s5eDoJKMTM2GXIiKyYDUZ4OdOZKoVLiLRVaMBXrykXicyRSTKajLAe1sa6G6u14lMEYm0mgxwKLbCFeAiEmU1G+CbV7ay6+gZsvlC2KWIiCxI7Qb4ilZm8gX2DOve4CISTTUb4KWRKDsOqRtFRKKpZgN8fXcTDekELyjARSSiajbAU8kEV65o5fmhU2GXIiKyIDUb4ABvXtXGC4dO6d7gIhJJNR/g4zN5Xjk+HnYpIiKXrLYDfHUbANsPqhtFRKKnrAA3sz81sxfMbLuZ3WdmDZUqbCls7GmmIZ1gm/rBRSSCFhzgZrYK+BNgwN2vApLARytV2FIonchUC1xEoqjcLpQU0GhmKSADHCq/pKV1tU5kikhELTjA3f0g8FfAfuAwcMrdfzV3OTPbamaDZjY4PDy88EoXyVU6kSkiEVVOF0oH8AFgHbASaDKzj89dzt3vcfcBdx/o6elZeKWLpHQi8/mDo+EWIiJyicrpQrkNeNXdh909CzwI3FiZspZO6UTm80O6IlNEoqWcAN8P3GBmGTMz4FZgZ2XKWjqpZILNOpEpIhFUTh/4E8ADwDPA88Fn3VOhupZU6YrMvE5kikiElDUKxd2/7O5XuPtV7v4Jd5+uVGFLqXQi89XjurWsiERHTV+JWXLuRKa6UUQkOhTgFE9kNqaT/OGAAlxEokMBTvFE5ptXt/HcgdGwSxERmTcFeGDLmnZ2HBpjOpcPuxQRkXlRgAe29HUwky/oF3pEJDIU4IEta9oBeHb/aKh1iIjMlwI8sKy1gVXtjTy7/2TYpYiIzIsCfJa39rWrBS4ikaEAn2XLmnYOjk5ybGwq7FJERC5KAT7L2X5wDScUkQhQgM/yppVtpJOmbhQRiQQF+CwN6SSbV7TqRKaIRIICfI4tazrYNnSKXL4QdikiIm9IAT7HljXtTGbzvHT0dNiliIi8IQX4HFv6OgBd0CMi1U8BPkdfZyNdTXU8s0/94CJS3RTgc5gZ1/Z38uTeE2GXIiLyhhTg53Htuk6GTk5y+NRk2KWIiFyQAvw8ruvvBODJV9UKF5HqpQA/jytXtNBcn1KAi0hVKyvAzazdzB4wsxfNbKeZvb1ShYUplUxwzdoOnlI/uIhUsXJb4F8HfuHuVwBvAXaWX1J1uK6/g5ePnuHk+EzYpYiI
nNeCA9zM2oCbgW8DuPuMu49WqK7QXbeuC0CtcBGpWuW0wNcBw8B3zOxZM/uWmTXNXcjMtprZoJkNDg8Pl7G6pXX16jbqkgkFuIhUrXICPAVcA3zT3bcA48AX5i7k7ve4+4C7D/T09JSxuqXVkE7ylr42ncgUkapVToAPAUPu/kQw/QDFQI+N69Z1sv3QGOPTubBLERF5nQUHuLsfAQ6Y2aZg1q3AjopUVSWu7e8kX3DdF0VEqlK5o1D+I/ADM9sGvBX4r2VXVEXetraDhMGTr46EXYqIyOukynmzuz8HDFSmlOrT0pDmTSvbePwVBbiIVB9diXkRN23s5tn9o+oHF5GqowC/iHds7CZXcN2dUESqjgL8Igb6O6hLJXhs1/GwSxEReQ0F+EU0pJNc29/Bo7sV4CJSXRTg83DTxm5ePHKa4dPTYZciInKWAnwe3rGxG4Df71ErXESqhwJ8Ht60so22xjSPqRtFRKqIAnwekgnj7eu7eGz3CO4edjkiIoACfN5uuqybg6OT7BuZCLsUERFAAT5vpX5wjUYRkWqhAJ+n/q4Mq9obeVTjwUWkSijA58nMuPnyHh7dfZyZXCHsckREFOCX4tYrejkznWNQl9WLSBVQgF+CGzd2UZdK8PCLx8IuRUREAX4pMnUpbtzQxSMKcBGpAgrwS3TrFb28enycV4bPhF2KiNQ4BfgletcVvQBqhYtI6BTgl2h1R4ZNy1p4eKcCXETCpQBfgFuu7OWpvScYm8qGXYqI1DAF+ALcckUvuYLzu5d1UY+IhKfsADezpJk9a2Y/q0RBUbClr532TJqHdx4NuxQRqWGVaIF/BthZgc+JjFQywbs29fLIS8fI5nVVpoiEo6wAN7PVwPuAb1WmnOh475tXMDqR5fd7RsIuRURqVLkt8L8GPg9csBlqZlvNbNDMBoeHh8tcXfX4Z5d101yf4v9sOxR2KSJSoxYc4GZ2B3DM3Z9+o+Xc/R53H3D3gZ6enoWuruo0pJPcdmUvv9pxVN0oIhKKclrgNwHvN7O9wA+BW8zsf1akqoh439Ur1Y0iIqFZcIC7+xfdfbW79wMfBR5x949XrLIIKHWj/OO2w2GXIiI1SOPAy9CQTvLuzcv45Y4j6kYRkSVXkQB39//r7ndU4rOiRqNRRCQsaoGXSd0oIhIWBXiZSt0oP99+mKlsPuxyRKSGKMAr4EPXrGZsKsevdWm9iCwhBXgF3Lihi1Xtjfx4cCjsUkSkhijAKyCRMD50zSp+t2uYw6cmwy5HRGqEArxCPvy2PtzhwWcOhl2KiNQIBXiFrOnKcP26Tu4fPIC7h12OiNQABXgFfWSgj70jEwzuOxl2KSJSAxTgFfTeNy+nqS7J/YMHwi5FRGqAAryCMnUp7rh6JT/bdphTk/q9TBFZXArwCvvE29cyMZNXK1xEFp0CvMKuWtXGdes6+c5je8npBlcisogU4IvgUzet4+DopK7MFJFFpQBfBO/evIzVHY3c++jesEsRkRhTgC+CZML45I39PLn3BM8PnQq7HBGJKQX4IvlX1/bRVJfkO4+9GnYpIhJTCvBF0tqQ5iMDffzvbYc4OKr7o4hI5SnAF9HWm9djGN94ZHfYpYhIDCnAF9HK9kbuvK6P+wcPsH9kIuxyRCRmFOCL7N+/ayPJhPE3j+wKuxQRiZkFB7iZ9ZnZb8xsh5m9YGafqWRhcbGstYGP37CWB58Z4pXhM2GXIyIxUk4LPAd8zt03AzcAnzazzZUpK17++J0bqE8l+frDaoWLSOUsOMDd/bC7PxM8Pw3sBFZVqrA46W6u564b+/npHw6x/aDGhYtIZVSkD9zM+oEtwBPneW2rmQ2a2eDw8HAlVhdJf/zODXQ11fNnP9lOoaAffBCR8pUd4GbWDPw98Fl3H5v7urvf4+4D7j7Q09NT7uoiq60xzZfedwXPHRjlR7pToYhUQFkBbmZpiuH9A3d/sDIlxdcH37qK69d18pVfvMiJ8ZmwyxGRiCtnFIoB
3wZ2uvtXK1dSfJkZ/+WDV3FmKsdXfv5i2OWISMSV0wK/CfgEcIuZPRf8vbdCdcXW5ctauPsd6/jR4AEe23087HJEJMLKGYXyqLubu1/t7m8N/v6xksXF1Wdvu5yNvc189kfPcfzMdNjliEhE6UrMEDTWJfnGv9nCqcksn/vxHzQqRUQWRAEekiuWt/Lnd2zm/708zH//3SthlyMiEaQAD9HHrl/De65azl/+8iUe3zMSdjkiEjEK8BCZGX/xoavp725i6/cHeeGQrtIUkflTgIesrTHN9z91HS0NKe669yn2jYyHXZKIRIQCvAqsbG/k+3dfR65Q4I/ufZIjp6bCLklEIkABXiU29rbwnU9ey/HT0/zL//YYOw697q4EIiKvoQCvIlvWdHD/v7sRd/jI3/2e37x0LOySRKSKKcCrzOaVrfzk0zfR393E3d99iq//ehczuULYZYlIFVKAV6HlbQ38+N++nTuuXsnXfv0y7//Go2wbGg27LBGpMua+dFcBDgwM+ODg4JKtLw4e2nGUP/vJ8wyfnubDb1vN1ps3sLG3ObR6JmfyHB2bKv6dnubk+AwnJ2YYnchyZjrH5Eye8ZkcubyTKxQoFACDVMJIJRPUJRM01SfJ1CVpaUjT1lj862yqo7u5nq7mOnpb6mlpSIe2jSLVxsyedveBufNTYRQj8/fuzcu4fn0nX/3Vy/zwqf3c//QQ775yGXdev4YbN3RRn0pWdH2FgnNkbIq9I+PsG5lg38gE+0+MM3RykqGTkxe8DW5LQ4qW+hSNdUkydSnqUgmSZiSC73gzuQLjM3mms3kms3kmZvKcmcoxmc2f9/Oa6pIsa2tgRVsDq9obWdneyKr2Rvo6M/R1Zlje2kAyYRXddpGoUQs8QkbOTPO9x/fx/cf3MjqRpakuyc2X93DD+i42LW/hiuUttGfq3vAz3J1Tk1kOjU5x+NQkB0cn2T8ywb4TE8HjOFPZc33u6aTR15FhdWeGVe2NrO5oZHlrA8taG+htraezqY72xjSp5MJ646ayecYms4yMz3D8zDTHz0xzbGyaI0Er/+DoFIdGJxk+/dqbfqUSxqqORvo6MkGoN7KmM3N2uiOTpnjHY5Hou1ALXAEeQVPZPI+/MsJDO47y8M6jHB07F24t9SnaMmnaM2nqU0nyBafgzuRMnpMTWU5NzpDNv3afN6QTrOnMsKYzQ39XE/3dTcFjhhVtjVXR0p3O5Tk0OsXQyQkOnJjkwMkJDpyY4MDJSfaPjHNyIvua5ZvqkqzuyLCqo9hyX9VRbMWvbGtgRXsjvS31pBd40BFZagrwmHJ3jo5N8+KRMV46cpojY1OcmsgyOpllOpcnmUiQNGhIJ2nP1NGeSdPVVMeKtkZWthe7J3pa6iPfWj0znePAiQn2n5gIunuKQX9wdJKDJycYm8q9Znkz6Gqqo7el+E2ip7menpbiN4qu5jo6m+rpyBT759sb62hpSJGoggOZ1CYFuNS0saksh0enOHRqksOjxe6ZY6enOHJqiuEz0xw/XezCyb3BrX2b61O0NKRork+RqU/RFPT3N9YlaUwnaEgnqU8lqE8lqUslSCcTwaORTiZIJYqPyYSRShjJ4C9RmjbDrDQfEsHzhM1aNniemvW+hAXvTxY/IxmsJ2FE/sAsRTqJKTWttSFN6/I0m5a3XHAZd2dsMseJiRlOjE8zOpFldCLLyYkZxqZynJ7KcmYqx+mpHBPZPBPTOU5OTJ49MTuZzTOTKzCVzVMtt3hPJ4NATyRIJo1UonhASQUHlXQiQTpVnF+XTJybnzx34CkeiErvfe38dKr42alkaZRRsFwqQToYeZRKButPWLDuRDAqqbhs6WBWOriV/lJzHnUwej0FuEjAzGjLpGnLpFnX3VTWZ+XyBbJ5ZyZfYCZXIFcokMs72XyBfMHJu5PL+9nn+YKfPV9RKEDenULh3OuFWcsVgvcW3MkXIFconH1/vuDkSo/5AtngeTZfCIZ2FufngnnZoM7S
8+lcgfHp3Nl5uYIzkyswky8U35d3poPnS32QShjnAt6K30CSs76FlL6hpJKlbzNc8BtMaYRUcs57i3+c/WZU+kZjxqzn574hnXsvJEqfFdSWCN5Tmv++N69gTVemov9NFOAii6DY8oRGKjvMs5rkZx0ESgenbOnAERwYcnknWyicvS7gtY/nDmjFA0twkAoOHAWfNX/WdGHWQaow6+A392DozgUPgPlC8UBZWr97aTnOfeas95Xmlw6a554Xlyn4uYNu8bXX//favLJVAS4i1aHYqk3SkI7vQaocZ8M/OACkk5XvAiprHJWZ3W5mL5nZbjP7QqWKEhGJukTQr1+fStJYl1zwtRJvuI6FvtHMksDfAu8BNgN3mtnmShUmIiJvrJxDwnXAbnd/xd1ngB8CH6hMWSIicjHl9IGvAg7Mmh4Crp+7kJltBbYGk2fM7KUFrq8bOL7A90ZZLW53LW4z1OZ21+I2w6Vv99rzzVz0k5jufg9wT7mfY2aD5xvIHne1uN21uM1Qm9tdi9sMldvucrpQDgJ9s6ZXB/NERGQJlBPgTwGXmdk6M6sDPgr8tDJliYjIxSy4C8Xdc2b2H4BfAkngXnd/oWKVvV7Z3TARVYvbXYvbDLW53bW4zVCh7V7Sm1mJiEjl6IbIIiIRpQAXEYmoSAR4LVyyb2Z9ZvYbM9thZi+Y2WeC+Z1m9pCZ7QoeO8KutdLMLGlmz5rZz4LpdWb2RLC/fxScJI8VM2s3swfM7EUz22lmb4/7vjazPw3+bW83s/vMrCGO+9rM7jWzY2a2fda88+5bK/qbYPu3mdk1l7Kuqg/wGrpkPwd8zt03AzcAnw628wvAw+5+GfBwMB03nwF2zpr+CvA1d98InATuDqWqxfV14BfufgXwForbH9t9bWargD8BBtz9KooDHz5KPPf1d4Hb58y70L59D3BZ8LcV+OalrKjqA5wauWTf3Q+7+zPB89MU/4deRXFbvxcs9j3gg6EUuEjMbDXwPuBbwbQBtwAPBIvEcZvbgJuBbwO4+4y7jxLzfU1x1FujmaWADHCYGO5rd/8tcGLO7Avt2w8A3/eifwLazWzFfNcVhQA/3yX7q0KqZUmYWT+wBXgCWObuh4OXjgDLwqprkfw18HmgEEx3AaPuXvoRyzju73XAMPCdoOvoW2bWRIz3tbsfBP4K2E8xuE8BTxP/fV1yoX1bVr5FIcBripk1A38PfNbdx2a/5sUxn7EZ92lmdwDH3P3psGtZYingGuCb7r4FGGdOd0kM93UHxdbmOmAl0MTruxlqQiX3bRQCvGYu2TezNMXw/oG7PxjMPlr6ShU8HgurvkVwE/B+M9tLsWvsFop9w+3B12yI5/4eAobc/Ylg+gGKgR7nfX0b8Kq7D7t7FniQ4v6P+74uudC+LSvfohDgNXHJftD3+21gp7t/ddZLPwXuCp7fBfzDUte2WNz9i+6+2t37Ke7XR9z9Y8BvgA8Hi8VqmwHc/QhwwMw2BbNuBXYQ431NsevkBjPLBP/WS9sc6309y4X27U+BPwpGo9wAnJrV1XJx7l71f8B7gZeBPcCXwq5nkbbxHRS/Vm0Dngv+3kuxT/hhYBfwa6Az7FoXafvfCfwseL4eeBLYDdwP1Idd3yJs71uBwWB//wToiPu+Bv4z8CKwHfgfQH0c9zVwH8V+/izFb1t3X2jfAkZxlN0e4HmKo3TmvS5dSi8iElFR6EIREZHzUICLiESUAlxEJKIU4CIiEaUAFxGJKAW4iEhEKcBFRCLq/wMqxa957LIVbgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# First, generate some random training data\n",
    "np.random.seed(0)\n",
    "x = np.random.uniform(0, 2, 100)\n",
    "y = x * 3 + 1 + np.random.normal(0, 0.5, 100)\n",
    "\n",
    "# Convert the training data to PyTorch tensors\n",
    "x = torch.from_numpy(x).float().view(-1, 1)\n",
    "y = torch.from_numpy(y).float().view(-1, 1)\n",
    "\n",
    "# Next, define a linear model and a loss function\n",
    "model = torch.nn.Linear(1, 1)\n",
    "loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "# Then create an Adam optimizer to train the model\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n",
    "\n",
    "# List for recording the loss so we can visualize training\n",
    "losses = []\n",
    "\n",
    "# Training loop\n",
    "for i in range(100):\n",
    "    # Forward pass: compute predictions and the loss\n",
    "    y_pred = model(x)\n",
    "    loss = loss_fn(y_pred, y)\n",
    "\n",
    "    # Record the loss for plotting later\n",
    "    losses.append(loss.item())\n",
    "\n",
    "    # Backward pass and parameter update\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "# Visualize the training process\n",
    "plt.plot(losses)\n",
    "plt.ylim((0, 15))\n",
    "plt.show()"
   ]
  },
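  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To connect the update rules in 6.10.1 with the `optimizer.step()` call above, here is a minimal from-scratch sketch of Adam in NumPy, fitting the same kind of noisy line $y = 3x + 1$. It is an illustration under stated assumptions, not PyTorch's actual implementation: the helper name `adam_step`, the zero initialization of `w`, `m`, and `v`, the hyperparameter values, and the 500-step budget are all choices made for this example. The gradient of the mean squared error is written out by hand instead of using autograd. With these settings the fitted slope and intercept should end up close to 3 and 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def adam_step(w, g, m, v, t, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):\n",
    "    # One Adam update for parameters w given gradient g (step counter t starts at 1).\n",
    "    # adam_step is a hypothetical helper for illustration, not a library function.\n",
    "    m = beta1 * m + (1 - beta1) * g        # moving average of the gradient\n",
    "    v = beta2 * v + (1 - beta2) * g ** 2   # moving average of the squared gradient\n",
    "    m_hat = m / (1 - beta1 ** t)           # bias-corrected first moment\n",
    "    v_hat = v / (1 - beta2 ** t)           # bias-corrected second moment\n",
    "    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)\n",
    "    return w, m, v\n",
    "\n",
    "\n",
    "# Synthetic data generated the same way as in the PyTorch example: y = 3x + 1 + noise\n",
    "rng = np.random.default_rng(0)\n",
    "x = rng.uniform(0, 2, 100)\n",
    "y = 3 * x + 1 + rng.normal(0, 0.5, 100)\n",
    "\n",
    "w = np.zeros(2)       # [slope, intercept], assumed to start at zero\n",
    "m = np.zeros_like(w)  # first-moment state\n",
    "v = np.zeros_like(w)  # second-moment state\n",
    "losses = []\n",
    "\n",
    "for t in range(1, 501):\n",
    "    err = w[0] * x + w[1] - y            # prediction error\n",
    "    losses.append(np.mean(err ** 2))     # mean squared error\n",
    "    # Hand-written gradient of the MSE with respect to slope and intercept\n",
    "    g = np.array([2 * np.mean(err * x), 2 * np.mean(err)])\n",
    "    w, m, v = adam_step(w, g, m, v, t)\n",
    "\n",
    "print('slope, intercept:', w)\n",
    "print('first loss:', losses[0], 'last loss:', losses[-1])"
   ]
  },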
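  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Tip from 梗直哥: Adam still has quite a few details to watch for in practice, and tuning it well takes hands-on experience, so practice often and build your own intuition. You are also welcome to take the advanced course 《梗直哥的深度学习必修课：python实战》 to speed up that process.**"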
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Next: 6-11 Learning Rate Schedulers](./6-11%20学习率调节器.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
